package data_deepprocessing.algorithm.crfs;

import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.springframework.context.ApplicationContext;
import org.springframework.context.support.ClassPathXmlApplicationContext;

import data_deepprocessing.algorithm.evaluation.EvaluateResult;
import data_deepprocessing.algorithm.evaluation.ResultBean;
import data_deepprocessing.util.TxtOperate;




/**
 * k-fold (ten-fold by default) cross-validation driver for the CRF entity extractor.
 *
 * <p>The files under {@code dirName} are shuffled and dealt round-robin into {@code k} folds. Each
 * fold in turn is the test set while the other {@code k - 1} folds form the training set; the CRF
 * is run through {@link CRFsSingleVersion#callListCRFs} and its output is scored per entity tag
 * ("S" = 症状 symptom, "D" = 疾病 disease, "I" = 诱因 inducing factor) with {@link EvaluateResult}.
 * A report per fold is written to {@code TxtOperate.newTxt(evaluatePath, "evaluate" + fold)} and
 * the k-fold averages to {@code TxtOperate.newTxt(evaluatePath, "stander")}.
 *
 * <p>Improvements suggested by the original author:
 * <ol>
 *   <li>build the train/test split from chunk ids instead of a random shuffle;</li>
 *   <li>keep one gold-standard data set and one training data set;</li>
 *   <li>then only a result path and an evaluation path need to be supplied;</li>
 *   <li>improve the CRF itself, mainly by finding better features.</li>
 * </ol>
 *
 * <p>With small changes the same skeleton could cross-validate Bootstrapping; that is not done yet
 * because its output must also be combined with the C-Value / NC-Value data. (Original author's
 * note, context unclear: ten-fold cross-validation could not be run yet; the data is split into
 * two kinds.)
 *
 * @author YUANYUHU
 */
public class CRFsCrossValid_tool {

	/** Number of folds. */
	private final static Integer k;
	/** Iteration count handed to {@code callListCRFs}. */
	private final static Integer iterations;
	/** Directory holding the data set; its files are split into folds and (presumably) serve as the gold standard. */
	private final static String dirName;
	private final static String dicPath;
	private final static String modelPath;
	/** Root of the per-fold result directories {@code resultPath/result<fold>/}. */
	private final static String resultPath;
	private final static String evaluatePath;
	/** Only "CRFs" (case-insensitive) is currently executed. */
	private final static String algorithm;

	static {
		// Experiment root. Feature templates 1-4 each have their own data set directory
		// (YuanYuhuExperimentDataSet_crossvalid_1 ... _4); template 3 is the default and another one
		// can be selected with -Dcrfs.crossvalid.path=<directory ending in a separator>.
		String path = System.getProperty("crfs.crossvalid.path",
				"D:\\yyh_yuanyuhu_graduation_experimental\\CRFs\\CRFs不同的模板下的表现\\ExperimentDataSet\\YuanYuhuExperimentDataSet_crossvalid_3\\");
		dirName = path + "dirpath";
		dicPath = System.getProperty("user.dir") + "\\crf_dic";
		modelPath = path + "model";
		resultPath = path + "result";
		evaluatePath = path + "evaluate";
		k = 10;
		iterations = 30;
		algorithm = "crfs";
	}

	private CRFsSingleVersion crfsSingleVersion;

	/** Folds 0..k-1, filled by {@link #divideDataset()}. */
	private Map<Integer, List<File>> divided_dataset = initMap();
	/** Training files of the fold currently being run. */
	private List<File> trainList = new ArrayList<>();
	/** Test files of the fold currently being run (aliases the fold list). */
	private List<File> testList = new ArrayList<>();

	public CRFsSingleVersion getCrfsSingleVersion() {
		return crfsSingleVersion;
	}

	public void setCrfsSingleVersion(CRFsSingleVersion crfsSingleVersion) {
		this.crfsSingleVersion = crfsSingleVersion;
	}

	/** Loads the Spring context, runs the cross-validation bean and closes the context afterwards. */
	public static void main(String[] args) {
		ClassPathXmlApplicationContext context = new ClassPathXmlApplicationContext("applicationContext.xml");
		try {
			CRFsCrossValid_tool service = (CRFsCrossValid_tool) context.getBean("cRFsCrossValid_tool");
			try {
				service.function();
			} catch (Exception e) {
				e.printStackTrace();
			}
		} finally {
			// Release the context (runs bean destroy callbacks) instead of leaking it.
			context.close();
		}
	}

	/** Creates {@code k} empty folds keyed 0..k-1. */
	private Map<Integer, List<File>> initMap() {
		Map<Integer, List<File>> folds = new HashMap<>();
		for (int i = 0; i < k; i++) {
			folds.put(i, new ArrayList<File>());
		}
		return folds;
	}

	/**
	 * Splits the data set into folds and runs the cross-validation. Failures during the run are
	 * printed rather than thrown, because this method declares no checked exceptions.
	 */
	public void function() {
		divideDataset();
		try {
			crossValidation();
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	/**
	 * Shuffles the files of {@code dirName} and deals them round-robin into fresh folds, so fold
	 * sizes differ by at most one.
	 *
	 * @throws IllegalStateException if {@code dirName} cannot be listed
	 */
	private void divideDataset() {
		File[] files = new File(dirName).listFiles();
		if (files == null) {
			throw new IllegalStateException("Data set directory is missing or unreadable: " + dirName);
		}
		// Fresh folds so a repeated call does not add every file a second time.
		divided_dataset = initMap();
		List<File> shuffled = new ArrayList<>(files.length);
		Collections.addAll(shuffled, files);
		Collections.shuffle(shuffled);
		for (int i = 0; i < shuffled.size(); i++) {
			divided_dataset.get(i % k).add(shuffled.get(i));
		}
	}

	/** Returns a new list holding the files of every fold except {@code testFold}. */
	private List<File> buildTrainList(int testFold) {
		List<File> train = new ArrayList<>();
		for (int fold = 0; fold < k; fold++) {
			if (fold != testFold) {
				train.addAll(divided_dataset.get(fold));
			}
		}
		return train;
	}

	/**
	 * Runs every fold: trains/tests the CRF, scores the test files, writes the per-fold report and
	 * finally the k-fold averages.
	 *
	 * @throws Exception propagated from the CRF run, the evaluator or the report writer
	 */
	private void crossValidation() throws Exception {
		Map<Tag, CrossFoldTotals> totals = new EnumMap<>(Tag.class);
		for (Tag tag : Tag.values()) {
			totals.put(tag, new CrossFoldTotals());
		}

		for (int datasetCode = 0; datasetCode < k; ++datasetCode) {
			testList = divided_dataset.get(datasetCode);
			trainList = buildTrainList(datasetCode);
			String foldResultDir = resultPath + "/result" + datasetCode + "/";

			if (algorithm.equalsIgnoreCase("CRFs")) {
				crfsSingleVersion.callListCRFs(trainList, testList, foldResultDir, modelPath, dicPath, iterations);
			}

			File evaluate = TxtOperate.newTxt(evaluatePath, "evaluate" + datasetCode);
			Map<Tag, TagStats> foldStats = evaluateFold(testList, foldResultDir, evaluate);
			writeFoldReport(foldStats, evaluate);
			for (TagStats stats : foldStats.values()) {
				totals.get(stats.tag).add(stats);
			}
		}
		writeSummary(totals);
	}

	/**
	 * Scores every test file of one fold against the file of the same name in {@code dirName}.
	 *
	 * @param testFiles     files of the test fold
	 * @param foldResultDir result directory passed to {@code callListCRFs} for this fold (ends in "/")
	 * @param evaluate      per-fold report file; a "文件编号:" header line is appended per test file
	 * @return per-tag accumulators, iterated in {@link Tag} declaration order
	 */
	private static Map<Tag, TagStats> evaluateFold(List<File> testFiles, String foldResultDir, File evaluate)
			throws Exception {
		Map<Tag, TagStats> stats = new EnumMap<>(Tag.class);
		for (Tag tag : Tag.values()) {
			stats.put(tag, new TagStats(tag, testFiles.size()));
		}
		for (File file : testFiles) {
			String testFileName = file.getName();
			EvaluateResult er = new EvaluateResult(foldResultDir + testFileName, dirName + "/" + testFileName, evaluate);
			try {
				TxtOperate.writeTxtFile("文件编号:" + testFileName + "\r\n", evaluate, true);
			} catch (Exception e) {
				// Best-effort header: a failed write must not abort scoring of this file.
				e.printStackTrace();
			}
			for (TagStats tagStats : stats.values()) {
				tagStats.add(er.evaluateResult(tagStats.tag.code));
			}
		}
		return stats;
	}

	/** Appends the macro averages, entity counts and micro scores of one fold to {@code evaluate}. */
	private static void writeFoldReport(Map<Tag, TagStats> stats, File evaluate) throws Exception {
		TxtOperate.writeTxtFile("============================================================================================="+"\r\n", evaluate, true);
		TxtOperate.writeTxtFile("以下为计算每个病例块的精度，Recall和F1值，然后求得平均值："+"\r\n", evaluate, true);
		for (TagStats s : stats.values()) {
			String label = s.tag.label;
			TxtOperate.writeTxtFile(label + "的Precision= " + s.macroPrecision() + "-----" + label + "的Recall=" + s.macroRecall()
					+ "-----" + label + "的F1=" + s.macroF1() + "\r\n", evaluate, true);
		}
		TxtOperate.writeTxtFile("=============================================================================================="+"\r\n", evaluate, true);
		for (TagStats s : stats.values()) {
			String label = s.tag.label;
			TxtOperate.writeTxtFile("抽取出的" + label + "的实体总和：" + s.extract + "-----正确抽取出的" + label + "的实体总和：" + s.correct
					+ "-----正确抽取出的" + label + "B的实体总和：" + s.correctB + "-----标准病历中" + label + "的实体总和：" + s.total + "\r\n",
					evaluate, true);
		}
		TxtOperate.writeTxtFile("=============================================================================================="+"\r\n", evaluate, true);
		for (TagStats s : stats.values()) {
			TxtOperate.writeTxtFile(scoreLineWithB(s.tag.label, s.microPrecisionB(), s.microPrecision(), s.microRecall(), s.microF1()),
					evaluate, true);
		}
	}

	/** Writes the k-fold averages of the macro and micro scores to the "stander" file. */
	private static void writeSummary(Map<Tag, CrossFoldTotals> totals) throws Exception {
		File standerFile = TxtOperate.newTxt(evaluatePath, "stander");
		TxtOperate.writeTxtFile("平均后总和在求得k重交叉验证的平均："+"\r\n", standerFile, true);
		TxtOperate.writeTxtFile("==============================================================================================================="+"\r\n", standerFile, true);
		for (Map.Entry<Tag, CrossFoldTotals> entry : totals.entrySet()) {
			CrossFoldTotals t = entry.getValue();
			TxtOperate.writeTxtFile(scoreLine(entry.getKey().label, t.macroPrecision / k, t.macroRecall / k, t.macroF1 / k), standerFile, true);
		}
		TxtOperate.writeTxtFile("==============================================================================================================="+"\r\n", standerFile, true);
		TxtOperate.writeTxtFile("计算正确实体后，再求得k重交叉验证的平均："+"\r\n", standerFile, true);
		for (Map.Entry<Tag, CrossFoldTotals> entry : totals.entrySet()) {
			CrossFoldTotals t = entry.getValue();
			TxtOperate.writeTxtFile(scoreLineWithB(entry.getKey().label, t.microPrecisionB / k, t.microPrecision / k, t.microRecall / k,
					t.microF1 / k), standerFile, true);
		}
	}

	/** Formats "{label}的精度=…-----{label}的recall=…-----{label}的F1值=…" followed by CRLF. */
	private static String scoreLine(String label, double precision, double recall, double f1) {
		return label + "的精度=" + precision + "-----" + label + "的recall=" + recall + "-----" + label + "的F1值=" + f1 + "\r\n";
	}

	/** Same as {@link #scoreLine} prefixed with "{label}B的精度=…-----". */
	private static String scoreLineWithB(String label, double precisionB, double precision, double recall, double f1) {
		return label + "B的精度=" + precisionB + "-----" + scoreLine(label, precision, recall, f1);
	}

	/** Returns numerator / denominator, or 0.0 when the denominator is 0 so one empty fold cannot turn the k-fold sums into NaN. */
	private static double ratio(double numerator, double denominator) {
		return denominator == 0.0 ? 0.0 : numerator / denominator;
	}

	/** Entity tags passed to {@code EvaluateResult.evaluateResult}, in report order. */
	private enum Tag {
		SYMPTOM("S", "症状", false),
		DISEASE("D", "疾病", true),
		INDUCEMENT("I", "诱因", true);

		/** Tag string understood by the evaluator. */
		private final String code;
		/** Label used in the report lines. */
		private final String label;
		/** Whether a file whose precision, recall and F1 are all 0.0 is left out of this tag's statistics. */
		private final boolean skipAllZero;

		Tag(String code, String label, boolean skipAllZero) {
			this.code = code;
			this.label = label;
			this.skipAllZero = skipAllZero;
		}
	}

	/** Accumulates one tag's results over the test files of a single fold. */
	private static final class TagStats {
		private final Tag tag;
		/** Files counted in the macro average; decremented for skipped all-zero files. */
		private int files;
		private double correct;
		private double correctB;
		private double extract;
		private double total;
		private double precisionSum;
		private double recallSum;
		private double f1Sum;

		TagStats(Tag tag, int files) {
			this.tag = tag;
			this.files = files;
		}

		/**
		 * Adds one file's result. For tags with {@code skipAllZero}, an all-zero result is dropped
		 * entirely (also from the entity counts).
		 * NOTE(review): that also drops gold entities of files where nothing was found, which may
		 * inflate the micro recall for D and I — confirm this is intended.
		 */
		void add(ResultBean result) {
			double precision = result.getAccuracy();
			double recall = result.getRecall();
			double f1 = result.getF1();
			if (tag.skipAllZero && precision == 0.0 && recall == 0.0 && f1 == 0.0) {
				files--;
				return;
			}
			correctB += result.getCorrectB();
			correct += result.getCorrect();
			extract += result.getExtract();
			total += result.getTotal();
			precisionSum += precision;
			recallSum += recall;
			f1Sum += f1;
		}

		/** Mean per-file precision. */
		double macroPrecision() {
			return ratio(precisionSum, files);
		}

		double macroRecall() {
			return ratio(recallSum, files);
		}

		double macroF1() {
			return ratio(f1Sum, files);
		}

		/** Precision over the summed entity counts of the fold. */
		double microPrecision() {
			return ratio(correct, extract);
		}

		double microPrecisionB() {
			return ratio(correctB, extract);
		}

		double microRecall() {
			return ratio(correct, total);
		}

		double microF1() {
			double p = microPrecision();
			double r = microRecall();
			return ratio(2 * p * r, p + r);
		}
	}

	/** Running sums over folds of one tag's per-fold scores; divided by k for the summary. */
	private static final class CrossFoldTotals {
		private double macroPrecision;
		private double macroRecall;
		private double macroF1;
		private double microPrecision;
		private double microPrecisionB;
		private double microRecall;
		private double microF1;

		void add(TagStats fold) {
			macroPrecision += fold.macroPrecision();
			macroRecall += fold.macroRecall();
			macroF1 += fold.macroF1();
			microPrecision += fold.microPrecision();
			microPrecisionB += fold.microPrecisionB();
			microRecall += fold.microRecall();
			microF1 += fold.microF1();
		}
	}
}








